import torch
import os

# Compute per-class weights for the RAF-DB training split from the number of
# sample files found in each class sub-directory. Rarer classes get larger
# weights (inverse-frequency weighting), and the weights are normalised to
# sum to 1.
#
# NOTE(review): the `weght_ce` name (sic) suggests these weights feed a
# weighted cross-entropy loss -- confirm against the training script. The
# misspelled name is kept so any code importing it keeps working.

class_name = ['1', '2', '3', '4', '5', '6', '7']
data_dir = '../rafdb/train/'


def _count_class_samples(root, class_names):
    """Return the number of sample files in ``root/<name>`` for each class.

    Only regular, non-hidden files are counted, so sub-directories and OS
    metadata files (e.g. ``.DS_Store``) do not inflate the totals. Each class
    name and its count are printed as they are read.

    Raises:
        FileNotFoundError: if a class directory does not exist.
    """
    counts = []
    for name in class_names:
        class_dir = os.path.join(root, name)
        with os.scandir(class_dir) as entries:
            n = sum(
                1
                for entry in entries
                if entry.is_file() and not entry.name.startswith('.')
            )
        print(name, n)
        counts.append(n)
    return counts


def _inverse_frequency_weights(counts, class_names):
    """Return a float32 tensor of weights proportional to 1/count, summing to 1.

    Raises:
        ValueError: if any class has zero samples. Without this check, 1/0
            would yield inf, and normalising would turn that weight into nan
            and all others into 0, silently corrupting the weights.
    """
    empty = [name for name, n in zip(class_names, counts) if n == 0]
    if empty:
        raise ValueError(f"no training samples found for class(es): {empty}")
    inv = 1.0 / torch.tensor(counts, dtype=torch.float32)
    # Dividing the counts by their total before inverting (as an earlier
    # version did) is a no-op after this final normalisation.
    return inv / inv.sum()


weght_ce = _count_class_samples(data_dir, class_name)
weights = _inverse_frequency_weights(weght_ce, class_name)
print(weights)


